In [1]:
import mxnet as mx
import numpy as np
# Fix MXNet's global RNG seed up front so every run of the
# notebook reproduces the same outputs.
mx.random.seed(1)
In [2]:
ctx = mx.gpu()
In [3]:
def pure_batch_norm(X, gamma, beta, eps=1e-5):
    """Batch-normalize X with a learned scale `gamma` and shift `beta`.

    Supports a dense input of shape (N, D) or a 2D-convolution input of
    shape (N, C, H, W). Statistics are computed from the current
    mini-batch only (no running averages): per feature column for the
    dense case, per channel (pooled over batch and spatial dims) for the
    convolutional case.

    Parameters
    ----------
    X : mx.nd.NDArray
        Input of rank 2 or 4.
    gamma, beta : mx.nd.NDArray
        Per-feature (dense) or per-channel (conv) scale and shift.
    eps : float
        Small constant added to the variance for numerical stability.

    Returns
    -------
    mx.nd.NDArray with the same shape as X.

    Raises
    ------
    ValueError
        If X is neither rank 2 nor rank 4.
    """
    rank = len(X.shape)
    if rank not in (2, 4):
        raise ValueError('only supports dense or 2dconv')

    if rank == 2:
        # Dense layer: one mean/variance per column (feature).
        mu = mx.nd.mean(X, axis=0)
        var = mx.nd.mean((X - mu) ** 2, axis=0)
        standardized = (X - mu) / mx.nd.sqrt(var + eps)
        return gamma * standardized + beta

    # 2D convolution: one mean/variance per channel, pooled over the
    # batch and both spatial dimensions. Reshape the statistics once to
    # (1, C, 1, 1) so they broadcast against X in every expression below.
    N, C, H, W = X.shape
    mu = mx.nd.mean(X, axis=(0, 2, 3)).reshape((1, C, 1, 1))
    var = mx.nd.mean((X - mu) ** 2, axis=(0, 2, 3)).reshape((1, C, 1, 1))
    standardized = (X - mu) / mx.nd.sqrt(var + eps)
    return gamma.reshape((1, C, 1, 1)) * standardized + beta.reshape((1, C, 1, 1))
In [4]:
# Toy dense input: 3 samples x 2 features.
A = mx.nd.array([1, 2, 3, 6, 5, 7], ctx=ctx).reshape((3, 2))
A
Out[4]:
In [5]:
# Identity scale/shift (gamma=1, beta=0): the output is the plain
# per-column standardization of A.
pure_batch_norm(X=A,
gamma=mx.nd.array([1,1], ctx=ctx),
beta=mx.nd.array([0,0], ctx=ctx))
Out[5]:
In [6]:
# Toy conv input: 2 samples x 2 channels x 2x2 spatial = 16 elements.
# The original cell listed 17 values for this 16-element reshape; the
# stray trailing value only worked because legacy MXNet's NDArray.reshape
# silently truncated, and it raises an error on modern MXNet. The first
# 16 values (what the original view actually contained) are kept.
B = mx.nd.array([1,6,5,7,4,3,2,5,6,3,2,4,5,3,2,5], ctx=ctx).reshape((2, 2, 2, 2))
B
Out[6]:
In [7]:
# 1st sample, 1st channel: a 2x2 spatial slice of B.
B[0, 0, :, :]
Out[7]:
In [8]:
# 1st sample, 2nd channel: a 2x2 spatial slice of B.
B[0, 1, :, :]
Out[8]:
In [9]:
# Identity scale/shift on the 4D input: each channel is standardized
# using its own mean/variance pooled over batch and spatial dims.
pure_batch_norm(X=B,
gamma=mx.nd.array([1,1], ctx=ctx),
beta=mx.nd.array([0,0], ctx=ctx))
Out[9]:
In [10]:
# Keep the normalized result so individual slices can be inspected below.
B_normalized = pure_batch_norm(X=B,
gamma=mx.nd.array([1,1], ctx=ctx),
beta=mx.nd.array([0,0], ctx=ctx))
In [11]:
B_normalized[0, 0, :, :]
Out[11]: